Dataset:
We will use the UCI Heart Disease medical dataset, which contains each patient's age, sex, and the results of a number of medical examinations.
Models:
Two models will be used: a Random Forest and an SVC.
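The models are loaded below from pickle files, so their training is not shown in this notebook. A minimal sketch of how such models could have been produced (the hyperparameters here are assumptions, not the saved models' actual settings; synthetic data stands in for the heart-disease training split):

```python
import pickle

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# stand-in data; the real notebook uses the heart-disease training split
X, y = make_classification(n_samples=200, n_features=20, random_state=0)

rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
# probability=True lets predict_proba be called on the SVC later,
# which dalex's default predict function for classifiers relies on
svc = SVC(probability=True, random_state=0).fit(X, y)

# persist the fitted models, as was done for ./Modele/random_forest and ./Modele/svc2
with open("random_forest.pkl", "wb") as f:
    pickle.dump(rf, f)
```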
import pickle
import dalex as dx
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split
# load the pre-trained models
with open("./Modele/random_forest", 'rb') as f:
    rf = pickle.load(f)
with open("./Modele/svc2", 'rb') as f:
    svc = pickle.load(f)
# load the dataset
data = pd.read_csv("./heart_data.csv")
# separate the target from the explanatory variables
y = data.target.values
x = data.drop(['target'], axis=1)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0, stratify=y)
# create the explainers
explainer_rf = dx.Explainer(rf, x_train, y_train, label = "Random Forest")
explainer_svc = dx.Explainer(svc, x_train, y_train, label = "SVC")
Preparation of a new explainer is initiated
  -> data              : 242 rows 20 cols
  -> target variable   : 242 values
  -> model_class       : sklearn.ensemble._forest.RandomForestClassifier (default)
  -> label             : Random Forest
  -> predict function  : <function yhat_proba_default at 0x00000202BAF33790> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 0.0322, mean = 0.46, max = 0.988
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -0.738, mean = -0.00088, max = 0.81
  -> model_info        : package sklearn
A new explainer has been created!
Preparation of a new explainer is initiated
  -> data              : 242 rows 20 cols
  -> target variable   : 242 values
  -> model_class       : sklearn.svm._classes.SVC (default)
  -> label             : SVC
  -> predict function  : <function yhat_proba_default at 0x00000202BAF33790> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 0.0525, mean = 0.446, max = 0.982
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -0.934, mean = 0.0127, max = 0.915
  -> model_info        : package sklearn
A new explainer has been created!
pdp_rf = explainer_rf.model_profile(random_state=0)
pdp_svc = explainer_svc.model_profile(random_state=0)
# pdp_rf.plot(geom="profiles") <- the profiles were computed separately and only the rendered plot is imported here, due to memory usage
from IPython.display import Image
Image("PDP.png")
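The PDP curves produced by dalex above can be cross-checked for a single feature with scikit-learn's own `partial_dependence`, which averages the model's prediction over the data while one feature is swept along a grid (shown here on synthetic data, since the real profiles were computed on `x_train`):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import partial_dependence

X, y = make_classification(n_samples=200, n_features=20, random_state=0)
model = RandomForestClassifier(random_state=0).fit(X, y)

# average predicted probability of the positive class as feature 0 is varied
pd_res = partial_dependence(model, X, features=[0], kind="average")
curve = pd_res["average"][0]  # one PDP curve for feature 0
```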
Conclusions:
When analysing PDP profiles, it is also worth checking how they look when grouped by some variable. Here we will use sex and ca. For better readability, the plots below were produced only for a handful of explanatory variables.
pdp_rf_sex = explainer_rf.model_profile(random_state=0, groups ='sex')
pdp_rf_sex.plot(color='_groups_', variables = ['age', 'trestbps','chol','fbs','restecg','thalach','oldpeak','ca'])
Conclusions:
pdp_rf_ca = explainer_rf.model_profile(random_state=0, groups = 'ca')
pdp_rf_ca.plot(color='_groups_', variables = ['age','trestbps','chol','thalach','exang','oldpeak','thal_n','thal_rd'])
Conclusions:
pdp_rf.plot([pdp_svc])
Conclusions:
ale_rf = explainer_rf.model_profile(type = 'accumulated', random_state=0)
ale_svc = explainer_svc.model_profile(type = 'accumulated', random_state=0)
ale_rf.plot(ale_svc)
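Unlike PDP, ALE accumulates local prediction differences within narrow bins of the feature, which makes it more robust when features are correlated. A minimal numpy sketch of first-order ALE for one numeric feature (an illustration of the idea, not dalex's implementation; the centring convention is an assumption):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

def ale_1d(predict, X, feature, n_bins=10):
    """First-order ALE curve for one numeric feature."""
    # bin edges at empirical quantiles, so each bin holds similar numbers of rows
    z = np.quantile(X[:, feature], np.linspace(0, 1, n_bins + 1))
    effects = []
    for lo, hi in zip(z[:-1], z[1:]):
        mask = (X[:, feature] >= lo) & (X[:, feature] <= hi)
        if not mask.any():
            effects.append(0.0)
            continue
        X_lo, X_hi = X[mask].copy(), X[mask].copy()
        X_lo[:, feature] = lo  # push the bin's rows to the bin edges
        X_hi[:, feature] = hi
        # local effect: mean prediction difference across the bin
        effects.append(np.mean(predict(X_hi) - predict(X_lo)))
    ale = np.cumsum(effects)   # accumulate the local effects
    return z[1:], ale - ale.mean()  # centre the curve around zero

X, y = make_classification(n_samples=300, n_features=5, random_state=0)
model = RandomForestClassifier(random_state=0).fit(X, y)
grid, ale = ale_1d(lambda A: model.predict_proba(A)[:, 1], X, feature=0)
```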
Conclusions:
# relabel the results so the two profile types are distinguishable on one plot
ale_rf.result['_label_'] = "ALE"
pdp_rf.result['_label_'] = "PDP"
pdp_rf.plot(ale_rf)
Conclusions: